import torch
import torch.nn as nn


class LLSTM(nn.Module):
    """Sequence classifier: stacked LSTM followed by a linear head on the last time step.

    Args:
        input_size: Number of features per time step (last dim of the input).
        hidden_size: Number of features in each LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of outputs produced by the linear head.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first=True: the LSTM consumes and produces (batch, seq, feature).
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, device=None):
        """Run the LSTM over ``x`` and apply the linear head to the final time step.

        Args:
            x: Tensor of shape (batch, seq_len, input_size) (batch-first layout).
            device: Device on which to create the zero initial states.
                Defaults to ``x.device`` when omitted.

        Returns:
            Tensor of shape (batch, num_classes): the raw output of the linear
            head (no activation is applied).
        """
        if device is None:
            device = x.device

        # Initial hidden state h0 and cell state c0, both zero and of identical
        # shape (num_layers, batch, hidden_size). They are created directly on
        # the target device and with the input's dtype, so that non-float32
        # inputs (e.g. a model converted with .double()) do not hit a dtype
        # mismatch inside nn.LSTM.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=device, dtype=x.dtype)

        # out: (batch, seq_len, hidden_size); the final (h_n, c_n) is unused.
        out, _ = self.lstm(x, (h0, c0))

        # Decode only the hidden state of the last time step.
        return self.fc(out[:, -1, :])


if __name__ == '__main__':
    # Smoke test: build the model and print the output shape for a dummy batch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    inputs = torch.ones((10, 1, 69, 6))
    # NOTE(review): reshape reinterprets memory rather than swapping axes. If
    # 69 is the feature axis and 6 the time axis of the original layout,
    # `inputs.squeeze(1).transpose(1, 2)` is probably what is intended —
    # confirm. (With an all-ones tensor both give the same result.)
    inputs = inputs.reshape(-1, 6, 69).to(device)

    # The model must live on the same device as the inputs and the initial
    # states that forward() creates on `device`; otherwise a CUDA machine
    # raises a device-mismatch RuntimeError.
    model = LLSTM(69, 128, 2, 2).to(device)
    with torch.no_grad():
        outputs = model(inputs, device)
    print(outputs.shape)




